###########################################################################################################################
# GOV2001
# Final Paper
# Martin Saveski and Abdullah Almaatouq
# May 4, 2015
###########################################################################################################################

# Make the project's "extension" directory the working directory. The relative
# paths used further down ("../data/..." inputs in load_data, "figures/..." and
# the JPEG output in the plotting section) are all resolved against it.
# NOTE(review): this is an absolute, user-specific path (presumably valid only
# on the original author's machine) -- adjust it, or start R from that
# directory, before running the script elsewhere.
setwd("~/Dropbox (MIT)/GOV2001/Product space/extension")

###########################################################################################################################

# Build the country-year panel relating economic complexity (ECI) to growth.
#
# For every base year in seq(start_year, end_year - step, by = step) this
#   1. takes GDP (PPP) of the base year from country_features.csv,
#   2. computes growth as the ratio gdp_ppp_<year + step> / gdp_ppp_<year>,
#   3. drops countries with any missing value,
#   4. joins the base year's ECI score by country code,
# and stacks the per-year results.
#
# Args:
#   start_year: first base year.
#   end_year:   last year for which GDP is needed; the last base year is
#               therefore end_year - step.
#   step:       number of years between base year and target year; also the
#               spacing between consecutive base years. Must be positive and
#               no larger than end_year - start_year.
#   one_pair:   if TRUE, only the first (year, year + step) pair is loaded.
#
# Returns:
#   A data.frame with columns code, gdp_ppp, growth (a ratio, not logged) and
#   eci; one row per country and base year that has complete data.
#
# Input files are read relative to the current working directory:
#   ../data/country_features.csv        (uses columns code, gdp_ppp_<year>)
#   ../data/ECI data/eci_all_years.csv  (uses columns code, year, eci)
load_data <- function(start_year, end_year, step, one_pair = FALSE) {
  # Fail early with a readable message instead of seq()'s "wrong sign in 'by'"
  # error, and before any file is read.
  if (step <= 0 || (end_year - step) < start_year) {
    stop(
      sprintf(
        "no year pair available: start_year = %s, end_year = %s, step = %s",
        start_year, end_year, step
      ),
      call. = FALSE
    )
  }

  vars_to_keep <- c("gdp_ppp")

  country_features <- read.csv("../data/country_features.csv", header = TRUE)
  eci_all_years <- read.csv("../data/ECI data/eci_all_years.csv", header = TRUE)

  base_years <- seq(start_year, end_year - step, by = step)
  if (isTRUE(as.logical(one_pair))) {
    # don't load all pairs
    base_years <- base_years[1]
  }

  # One data frame per base year; bound together once at the end rather than
  # growing a data frame with rbind() on every iteration.
  per_year <- lapply(base_years, function(year) {
    country_features_year <- data.frame(code = country_features$code)

    for (var in vars_to_keep) {
      f_name <- sprintf("%s_%d", var, year)
      # `[, f_name]` (not `[[`) so that a missing column raises an error.
      country_features_year[, var] <- country_features[, f_name]
    }

    gdp_future <- country_features[, sprintf("gdp_ppp_%d", year + step)]
    gdp_base <- country_features[, sprintf("gdp_ppp_%d", year)]
    country_features_year$growth <- gdp_future / gdp_base

    keep <- complete.cases(country_features_year)
    country_features_year <- country_features_year[keep, ]

    # ECI scores of the base year, joined on the country code
    eci_year <- eci_all_years[eci_all_years$year == year, ]
    merge(country_features_year, eci_year[, c("code", "eci")], by = "code")
  })

  do.call(rbind, per_year)
}

###########################################################################################################################

# Repeated k-fold cross-validation of the linear model growth ~ eci.
#
# Each repetition shuffles the rows into `nfolds` folds, fits the model on all
# rows outside a fold, predicts the rows inside it, and computes the RMSE of
# the out-of-fold predictions over the whole data frame.
#
# Args:
#   df:     data.frame with numeric columns `growth` (response) and `eci`.
#   nfolds: number of folds (>= 2). If it exceeds nrow(df) the surplus folds
#           are empty and are skipped, so this behaves as leave-one-out.
#   nrep:   number of repetitions (>= 1), each with a fresh random partition.
#
# Returns:
#   list(mean = <mean RMSE over repetitions>, sd = <sd of RMSE over repetitions>).
#   `sd` is NA when nrep is 1.
#
# Side effect: calls set.seed(0) so results are reproducible; this resets the
# global RNG state of the session.
cross.validation <- function(df, nfolds, nrep) {
  if (nfolds < 2) {
    stop("nfolds must be at least 2", call. = FALSE)
  }
  if (nrep < 1) {
    stop("nrep must be at least 1", call. = FALSE)
  }

  set.seed(0)

  n <- nrow(df)
  all_rmse <- numeric(nrep)

  for (r in seq_len(nrep)) {
    # Balanced fold labels, randomly assigned to rows.
    cv.ind <- rep(seq_len(nfolds), length.out = n)
    cv.ind <- sample(cv.ind, replace = FALSE)
    preds <- rep(NA_real_, n)

    for (i in seq_len(nfolds)) {
      in_fold <- cv.ind == i
      # An empty fold (nfolds > n) has nothing to predict. Logical masks are
      # used because df[-integer(0), ] would select *no* rows for training.
      if (!any(in_fold)) {
        next
      }

      lm.reg <- lm(growth ~ eci, data = df[!in_fold, ])
      preds[in_fold] <- predict(lm.reg, newdata = df[in_fold, ])
    }

    all_rmse[r] <- sqrt(mean((df$growth - preds)^2))
  }

  list(mean = mean(all_rmse), sd = sd(all_rmse))
}

###########################################################################################################################

# Evaluate the growth ~ eci model for prediction horizons of 1 to 20 years.
# For each horizon the 1985-2005 panel is rebuilt, growth is log-transformed
# and the model is scored with 20 repetitions of 10-fold cross-validation.
years_in_future <- 1:20
n_horizons <- length(years_in_future)

# One entry per horizon (all on the log-growth scale):
means <- numeric(n_horizons)   # mean CV RMSE
sds <- numeric(n_horizons)     # sd of the CV RMSE across repetitions
growth <- numeric(n_horizons)  # median log growth
cors <- numeric(n_horizons)    # correlation between ECI and log growth

for (k in seq_len(n_horizons)) {
  step <- years_in_future[k]
  cat("step: ", step, "\n")

  df <- load_data(1985, 2005, step = step, one_pair = FALSE)
  df$growth <- log(df$growth)

  cv.res <- cross.validation(df, nfolds = 10, nrep = 20)

  means[k] <- cv.res$mean
  sds[k] <- cv.res$sd
  growth[k] <- median(df$growth)
  cors[k] <- cor(df$eci, df$growth)
}

# ---- Summary table ----
# Rows: horizon, exp(mean CV RMSE), and exp(mean CV RMSE) / exp(median log
# growth). `means` and `growth` are on the log scale, hence the exp().
rbind(years_in_future, exp(means), exp(means) / exp(growth))

# Back-transformed series shared by the figures below.
rmse_ratio <- exp(means)
growth_ratio <- exp(growth)
ratio_ylim <- c(1, 2.3)

# ---- Correlation between ECI and log growth, by horizon ----
plot(years_in_future, cors, type = "l")

# ---- p1: RMSE and growth on one axis, written to PDF ----
# NOTE(review): the black line is labelled "Average Growth" but `growth` holds
# the median of log growth -- confirm which wording is intended.
pdf("figures/CV_rmse_growth.pdf", width = 7, height = 6.5)
plot(
  years_in_future, rmse_ratio,
  type = "l", col = "red", ylim = ratio_ylim,
  xlab = "Time increment in years", ylab = ""
)
lines(years_in_future, growth_ratio)
legend(
  "topleft",
  legend = c("Average Growth", "Root Mean Squared Error"),
  col = c("black", "red"), lty = 1
)
dev.off()

# ---- Same two series with a second (right-hand) axis, on the active device ----
# Extra right margin makes room for the second axis. par() is not restored.
# NOTE(review): "y2" looks like a placeholder label for the right-hand axis.
par(mar = c(5, 4, 4, 5) + .1)
plot(
  years_in_future, rmse_ratio,
  type = "l", col = "red", ylim = ratio_ylim,
  xlab = "Year in the Future", ylab = "Root Mean Squared Error"
)
par(new = TRUE)  # overlay the next plot on the same region
plot(
  years_in_future, growth_ratio,
  type = "l", ylim = ratio_ylim,
  xaxt = "n", yaxt = "n", xlab = "", ylab = ""
)
axis(4)
mtext("y2", side = 4, line = 3)
legend(
  "topleft",
  legend = c("Root Mean Squared Error", "Average Growth"),
  col = c("red", "black"), lty = 1
)


# ---- p2: RMSE relative to growth, written to PDF ----
pdf("figures/CV_cv_rmse.pdf", width = 7, height = 6.5)
plot(
  years_in_future, rmse_ratio / growth_ratio,
  type = "l", col = "blue", ylim = c(0.5, 1),
  xlab = "Time increment in years", ylab = "CV(RMSE)"
)
dev.off()

# ---- Error bars: RMSE +/- one sd across CV repetitions ----
# NOTE(review): errbar() is not defined in this file and is not part of base R;
# presumably Hmisc::errbar, which must already be attached -- confirm.
rmse_upper <- exp(means + sds)
rmse_lower <- exp(means - sds)
errbar(
  years_in_future, rmse_ratio, rmse_upper, rmse_lower,
  type = "l", col = "red", ylim = ratio_ylim,
  xlab = "Year in the Future", ylab = "Root Mean Squared Error"
)


# Same error-bar plot written to JPEG. Unlike the plots above, the values here
# are left on the log scale (no exp()).
jpeg("cv_all_year_pairs_ECI_ONLY.jpg", width = 800, height = 600)
errbar(
  years_in_future, means, means + sds, means - sds,
  type = "l", col = "red",
  xlab = "Year in the Future", ylab = "Root Mean Squared Error"
)
dev.off()

# END